package com.demo.component.locks.impl;

import com.demo.component.locks.DistributedLock;
import com.demo.uitls.UUIDUtil;
import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.BooleanUtils;
import org.apache.commons.lang.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.core.LocalVariableTableParameterNameDiscoverer;
import org.springframework.data.redis.core.BoundValueOperations;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.stereotype.Component;
import redis.clients.jedis.Jedis;
import redis.clients.util.SafeEncoder;

import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import java.io.Serializable;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * @Author DongXL
 * @Create 2017-09-20 20:50
 */

@Component
@Slf4j
@Aspect
public class RedisDistributedLock {

    @Resource
    private RedisTemplate redisTemplate;

    @PostConstruct
    public void init() {
        redisTemplate.setKeySerializer(new StringRedisSerializer());
        Jackson2JsonRedisSerializer jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
        ObjectMapper objectMapper = new ObjectMapper()
                .setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY)
                .enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(objectMapper);
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
    }


    public boolean setNx(String key, String value, int expireMilliseconds) {
        String status = (String) redisTemplate.execute((RedisCallback<String>) connection -> {
            Jedis jedis = (Jedis) connection.getNativeConnection();
            return jedis.set(key, value, "nx", "px", expireMilliseconds);
        });
        return StringUtils.isNotBlank(status) && "OK".equals(status);
    }


    /**
     * 防止程序异常 ，导致redis中存在未释放的脏数据锁，时间为30分钟
     */
    private static final int KEY_CLEAR_MILLIS = 30 * 60 * 1000;

    @Around("execution(* com.demo..*.*(..)) && @annotation(com.demo.component.locks.DistributedLock)")
    public Object lock(ProceedingJoinPoint pjp) throws Throwable {
        DistributedLock lockParams = getLockInfo(pjp);
        if (lockParams == null) {
            throw new IllegalArgumentException("配置参数错误");
        }
        String synKey = getSynKey(lockParams.key(), pjp);
        if (StringUtils.isBlank(synKey)) {
            throw new IllegalArgumentException("配置参数synKey错误");
        }
        boolean lock = false;
        Object ret = null;
        String threadUuid = UUIDUtil.INSTANCE.randomUUID();
        try {
            // 等待超时时间
            long maxWaitTime = System.currentTimeMillis() + lockParams.maxSleepMills();
            while (!lock) {
                //锁过期时间
                long expiredTime = System.currentTimeMillis() + lockParams.keepMills();
                lock = addLock(synKey, threadUuid, expiredTime, KEY_CLEAR_MILLIS);
                // 得到锁，没有人加过相同的锁
                if (lock) {
                    ret = pjp.proceed();
                    //锁不过期，未获取到锁，重试
                } else if (lockParams.keepMills() <= 0) {
                    // 继续等待获取锁
                    if (lockParams.toWait()) {
                        // 如果超过最大等待时间抛出异常
                        if (lockParams.maxSleepMills() > 0 && System.currentTimeMillis() > maxWaitTime) {
                            throw new TimeoutException("获取锁资源等待超时");
                        }
                        TimeUnit.MILLISECONDS.sleep(lockParams.sleepMills());
                    }
                    // 放弃等待
                    else {
                        break;
                    }
                }
                // 已过期，并且getAndSet后旧的时间戳依然是过期的，可以认为获取到了锁
                else if (expiredAndResetLock(synKey, threadUuid, expiredTime, KEY_CLEAR_MILLIS)) {
                    lock = true;
                    ret = pjp.proceed();
                }
                // 没有得到任何锁
                else {
                    // 继续等待获取锁
                    if (lockParams.toWait()) {
                        // 如果超过最大等待时间抛出异常
                        if (lockParams.maxSleepMills() > 0 && System.currentTimeMillis() > maxWaitTime) {
                            throw new TimeoutException("获取锁资源等待超时");
                        }
                        TimeUnit.MILLISECONDS.sleep(lockParams.sleepMills());
                    }
                    // 放弃等待
                    else {
                        break;
                    }
                }
            }
        } catch (Exception e) {
            log.error("redis lock error",e);
            throw e;
        } finally {
            // 如果获取到了锁，释放锁
            if (lock) {
                releaseLock(synKey, threadUuid);
            }
        }
        return ret;
    }

    /**
     * 获取包括方法参数上的key<br/>
     * redis key
     */
    private String getSynKey(String key, ProceedingJoinPoint pjp) {
        try {
            MethodSignature methodSignature = (MethodSignature) pjp.getSignature();
            Method method = methodSignature.getMethod();
            return parseKey(key, method, pjp.getArgs());
        } catch (Exception e) {
            log.error("获取锁失败", e);
        }
        return null;
    }

    @SuppressWarnings("unchecked")
    private static <T extends Annotation> T getAnnotation(final Class<T> annotationClass, final Annotation[] annotations) {
        if (annotations != null && annotations.length > 0) {
            for (final Annotation annotation : annotations) {
                if (annotationClass.equals(annotation.annotationType())) {
                    return (T) annotation;
                }
            }
        }
        return null;
    }

    /**
     * 获取RedisLock注解信息
     */
    private DistributedLock getLockInfo(ProceedingJoinPoint pjp) {
        try {
            MethodSignature methodSignature = (MethodSignature) pjp.getSignature();
            Method method = methodSignature.getMethod();
            DistributedLock lockInfo = method.getAnnotation(DistributedLock.class);
            return lockInfo;
        } catch (Exception e) {
            log.info("getLockInfo annotation error", e);
        }
        return null;
    }

    private BoundValueOperations<String, LockValue> getOperations(String key) {
        return redisTemplate.boundValueOps(key);
    }

    /**
     * 加锁 原子性
     * <p>
     * Set {@code value} for {@code key}, only if {@code key} does not exist.
     * <p>
     * See http://redis.io/commands/setnx
     *
     * @param key            must not be {@literal null}.
     * @param threadUuid     区分不同线程
     * @param expiredTime    业务过期时间，过期后允许下一个线程处理
     * @param clearKeyMillis 清理key的时间间隔，防止程序异常后存在脏数据的key
     * @return
     */
    private boolean addLock(String key, String threadUuid, long expiredTime, int clearKeyMillis) {
        byte[] value = redisTemplate.getValueSerializer().serialize(new LockValue(threadUuid, expiredTime));
        return setNx(key, SafeEncoder.encode(value), clearKeyMillis);
    }

    /**
     * 释放锁
     *
     * @param key
     * @param uuid
     */
    private void releaseLock(String key, String uuid) {
        LockValue lockValue = getOperations(key).get();
        if (lockValue.getUuid().equals(uuid)) {
            redisTemplate.delete(key);
        }
    }

    /**
     * 锁过期后重新设置锁
     *
     * @param key
     * @param uuid
     * @param expiredTime
     * @param clearKeyMillis
     * @return
     */
    private Boolean expiredAndResetLock(String key, String uuid, long expiredTime, long clearKeyMillis) {
        LockValue lockValue = getOperations(key).get();
        if (lockValue == null) {
            return false;
        }
        return System.currentTimeMillis() > lockValue.getExpiredTime() && getOperations(key).getAndSet(new LockValue(uuid, expiredTime)) != null && BooleanUtils.isTrue(redisTemplate.expire(key, clearKeyMillis, TimeUnit.MILLISECONDS));
    }

    /**
     * 获取缓存的key
     * key 定义在注解上，支持SPEL表达式
     *
     * @param key
     * @param method
     * @param args
     * @return
     */
    private String parseKey(String key, Method method, Object[] args) {
        //获取被拦截方法参数名列表(使用Spring支持类库)
        LocalVariableTableParameterNameDiscoverer u =
                new LocalVariableTableParameterNameDiscoverer();
        String[] paraNameArr = u.getParameterNames(method);
        //使用SPEL进行key的解析
        ExpressionParser parser = new SpelExpressionParser();
        //SPEL上下文
        StandardEvaluationContext context = new StandardEvaluationContext();
        //把方法参数放入SPEL上下文中
        for (int i = 0; i < paraNameArr.length; i++) {
            context.setVariable(paraNameArr[i], args[i]);
        }
        return parser.parseExpression(key).getValue(context, String.class);
    }

    public static class LockValue implements Serializable {

        /**
         * 标识线程唯一性
         */
        private String uuid;

        /**
         * 过期时间
         */
        private Long expiredTime;


        public LockValue() {
        }

        public LockValue(String uuid, Long expiredTime) {
            this.uuid = uuid;
            this.expiredTime = expiredTime;
        }

        public Long getExpiredTime() {
            return expiredTime;
        }

        public void setExpiredTime(Long expiredTime) {
            this.expiredTime = expiredTime;
        }

        public String getUuid() {
            return uuid;
        }

        public void setUuid(String uuid) {
            this.uuid = uuid;
        }
    }

}
